{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Context-FID Score Presentation\n",
    "## Required packages and function imports\n",
    "\n",
    "- Context-FID score: a metric that measures how well synthetic time series windows \"fit\" into the local context of the real time series."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import sys\n",
    "# Make the parent directory importable so that the Utils package resolves.\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "from Utils.context_fid import Context_FID\n",
    "from Utils.metric_utils import display_scores\n",
    "from Utils.cross_correlation import CrossCorrelLoss"
   ]
  },
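  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Loading\n",
    "\n",
    "Load the original (ground-truth) samples and the synthetic samples as NumPy arrays. The arrays are used as saved on disk; `iterations` sets how many times each metric is evaluated."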
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 5  # number of times each metric is evaluated\n",
    "# Ground-truth training windows\n",
    "ori_data = np.load('../toy_exp/samples/sine_ground_truth_24_train.npy')\n",
    "# ori_data = np.load('../OUTPUT/{dataset_name}/samples/{dataset_name}_norm_truth_{seq_length}_train.npy')\n",
    "# Synthetic windows to evaluate\n",
    "fake_data = np.load('../toy_exp/ddpm_fake_sines.npy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Context-FID Score\n",
    "\n",
    "- This Fréchet Inception Distance-like score is based on unsupervised time series embeddings. It scores how well fixed-length synthetic samples fit into the context of the (often much longer) true time series.\n",
    "\n",
    "- Lower is better: the models with the lowest Context-FID scores tend to be the ones that perform best on downstream tasks."
   ]
  },
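  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a point of reference, FID-style scores are usually computed as the Fréchet distance between two Gaussians fitted to embedding vectors. A sketch of that closed form is given below, assuming $\\mu_r, \\Sigma_r$ and $\\mu_g, \\Sigma_g$ are the mean and covariance of the embeddings of the real and generated windows. These symbols are introduced here for illustration; the internals of `Context_FID` are not shown in this notebook and may differ in detail.\n",
    "\n",
    "$$\\mathrm{FD} = \\lVert \\mu_r - \\mu_g \\rVert_2^2 + \\operatorname{Tr}\\left(\\Sigma_r + \\Sigma_g - 2\\,(\\Sigma_r \\Sigma_g)^{1/2}\\right)$$\n",
    "\n",
    "The first term penalizes a shift in the mean embedding and the trace term penalizes a mismatch in covariance, so two identical Gaussians give a distance of 0."
   ]
  },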
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 0:  context-fid = 0.007570160590888924 \n",
      "\n",
      "Iter 1:  context-fid = 0.008166182104461794 \n",
      "\n",
      "Iter 2:  context-fid = 0.007970870497297756 \n",
      "\n",
      "Iter 3:  context-fid = 0.007697393798102688 \n",
      "\n",
      "Iter 4:  context-fid = 0.008525671090148759 \n",
      "\n",
      "Final Score:  0.007986055616179984 ± 0.00047287486643304236\n"
     ]
    }
   ],
   "source": [
    "context_fid_score = []\n",
    "\n",
    "for i in range(iterations):\n",
    "    # Use every real sample and the first len(ori_data) synthetic samples.\n",
    "    context_fid = Context_FID(ori_data, fake_data[:ori_data.shape[0]])\n",
    "    context_fid_score.append(context_fid)\n",
    "    print(f'Iter {i}: ', 'context-fid =', context_fid, '\\n')\n",
    "\n",
    "display_scores(context_fid_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correlational Score\n",
    "\n",
    "- This metric assesses temporal dependency by taking the absolute error between the auto-correlation estimates computed from the real data and from the synthetic data.\n",
    "\n",
    "- For series with more than one feature dimension (d > 1), it uses the $\\ell_1$-norm of the difference between the cross-correlation matrices."
   ]
  },
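  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The multivariate case can be sketched as follows, assuming $\\rho^{r}_{ij}$ and $\\rho^{f}_{ij}$ denote the estimated correlation between features $i$ and $j$ in the real and synthetic samples. This notation is introduced here for illustration only; `CrossCorrelLoss` may apply additional scaling or averaging that is not shown in this notebook.\n",
    "\n",
    "$$\\mathrm{score} = \\sum_{i=1}^{d} \\sum_{j=1}^{d} \\left| \\rho^{r}_{ij} - \\rho^{f}_{ij} \\right|$$\n",
    "\n",
    "Under this form, a score of 0 means the two estimated correlation matrices coincide, and larger values indicate a larger mismatch in dependency structure."
   ]
  },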
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_choice(size, num_select=100):\n",
    "    \"\"\"Draw `num_select` indices uniformly from [0, size), with replacement.\"\"\"\n",
    "    select_idx = np.random.randint(low=0, high=size, size=(num_select,))\n",
    "    return select_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 0:  cross-correlation = 0.019852216581699826 \n",
      "\n",
      "Iter 1:  cross-correlation = 0.01816951370252664 \n",
      "\n",
      "Iter 2:  cross-correlation = 0.022373661672091448 \n",
      "\n",
      "Iter 3:  cross-correlation = 0.012407943886992933 \n",
      "\n",
      "Iter 4:  cross-correlation = 0.010309792931556355 \n",
      "\n",
      "Final Score:  0.01662262575497344 ± 0.006316425881906014\n"
     ]
    }
   ],
   "source": [
    "x_real = torch.from_numpy(ori_data)\n",
    "x_fake = torch.from_numpy(fake_data)\n",
    "\n",
    "correlational_score = []\n",
    "# Number of samples drawn from each set per iteration\n",
    "size = x_real.shape[0] // iterations\n",
    "\n",
    "for i in range(iterations):\n",
    "    real_idx = random_choice(x_real.shape[0], size)\n",
    "    fake_idx = random_choice(x_fake.shape[0], size)\n",
    "    corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')\n",
    "    loss = corr.compute(x_fake[fake_idx, :, :])\n",
    "    correlational_score.append(loss.item())\n",
    "    print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\\n')\n",
    "\n",
    "display_scores(correlational_score)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PT3.8",
   "language": "python",
   "name": "pt3.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
